import torch
import torch.nn.functional as F
import numpy as np


class LinfPGDAttack(object):
    """L-infinity signed-gradient attacks (PGD and FGSM) against a classifier.

    Both attacks increase ``F.cross_entropy(self.model(x), y)`` with respect
    to the input and return a detached adversarial batch that stays inside an
    ``epsilon`` L-infinity ball around the clean input and inside [0, 1].
    Gradients are taken only with respect to the input (via
    ``torch.autograd.grad``), so the model parameters' ``.grad`` fields are
    not modified.
    """

    def __init__(self, model):
        """Store the model to attack.

        Args:
            model: callable mapping an input batch to logits that are fed to
                ``F.cross_entropy`` (presumably a ``torch.nn.Module`` — confirm
                at call sites).
        """
        self.model = model

    def pgd_perturb(self, x_natural, y, epsilon, alpha, iter):
        """Run projected gradient ascent with sign steps from a random start.

        Args:
            x_natural: clean input batch; values are treated as valid only in
                [0, 1] (every returned value is clamped to that range).
            y: targets passed to ``F.cross_entropy``.
            epsilon: radius of the L-infinity ball around ``x_natural``.
            alpha: size of each signed-gradient step.
            iter: number of steps. The name shadows the builtin but is kept
                for keyword-argument compatibility with existing callers.

        Returns:
            A detached tensor with the shape and device of ``x_natural``.
        """
        # Detach once so the projection below never links the result back to
        # a caller tensor that requires grad.
        x_nat = x_natural.detach()
        x = x_nat + torch.zeros_like(x_nat).uniform_(-epsilon, epsilon)
        # Keep the random start in the valid range so the model is never
        # evaluated (or differentiated) at an out-of-range input.
        x = torch.clamp(x, 0, 1)
        # Projection bounds do not change across iterations.
        lower = x_nat - epsilon
        upper = x_nat + epsilon
        for _ in range(iter):
            x.requires_grad_()
            with torch.enable_grad():
                logits = self.model(x)
                loss = F.cross_entropy(logits, y)
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + alpha * torch.sign(grad.detach())
            # Project back onto the epsilon-ball, then onto the valid range.
            x = torch.min(torch.max(x, lower), upper)
            x = torch.clamp(x, 0, 1)
        return x

    def fgsm_perturb(self, x_natural, y, epsilon, alpha, data_init='zero'):
        """Take one signed-gradient step (FGSM), optionally from a random start.

        Args:
            x_natural: clean input batch; values are treated as valid only in
                [0, 1] (the result is clamped to that range).
            y: targets passed to ``F.cross_entropy``.
            epsilon: radius of the L-infinity ball around ``x_natural``.
            alpha: size of the signed-gradient step.
            data_init: ``'zero'`` starts from the clean input; any other value
                starts from a uniform perturbation in [-epsilon, epsilon],
                clipped so that ``x + delta`` lies in [0, 1].

        Returns:
            A detached tensor with the shape and device of ``x_natural``.
        """
        x = x_natural.detach()

        if data_init == 'zero':
            # zeros_like already matches x's device and dtype; forcing .cuda()
            # here broke CPU inputs.
            delta = torch.zeros_like(x)
        else:
            delta = torch.zeros_like(x).uniform_(-epsilon, epsilon)
            # Keep x + delta inside [0, 1] before the gradient is evaluated.
            delta = torch.clamp(delta, 0 - x, 1 - x)

        delta.requires_grad_()
        with torch.enable_grad():
            logits = self.model(x + delta)
            loss = F.cross_entropy(logits, y)
        grad = torch.autograd.grad(loss, [delta])[0]
        # One signed step, projected back onto the epsilon-ball.
        delta = torch.clamp(delta.detach() + alpha * torch.sign(grad), -epsilon, epsilon)
        return torch.clamp(x + delta, 0, 1)